{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/search/faiss-ebook/product-quantization/pq_faiss.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/search/faiss-ebook/product-quantization/pq_faiss.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PQ and Faiss\n",
    "\n",
    "In this notebook we will cover the product quantization (PQ) implementation in Faiss. First, we need to load our Sift1M dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# now define a function to read the fvecs file format of Sift1M dataset\n",
    "def read_fvecs(fp):\n",
    "    a = np.fromfile(fp, dtype='int32')\n",
    "    d = a[0]\n",
    "    return a.reshape(-1, d + 1)[:, 1:].copy().view('float32')\n",
    "\n",
    "# 1M samples, cut down to 500K\n",
    "xb = read_fvecs('../../../data/sift/sift_base.fvecs')[:500_000]\n",
    "# queries\n",
    "xq = read_fvecs('../../../data/sift/sift_query.fvecs')[0].reshape(1, -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize an index in Faiss using *only* PQ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "\n",
    "D = xb.shape[1]\n",
    "m = 8\n",
    "assert D % m == 0\n",
    "nbits = 8  # number of bits per subquantizer, k* = 2**nbits\n",
    "index = faiss.IndexPQ(D, m, nbits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PQ centroids (*reproduction values*) must be fitted to our dataset, so if we check the index's `is_trained` attribute we will see that it is `False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.is_trained"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So we must train it using our vectors `xb`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.train(xb)  # PQ training can take some time when using large nbits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we can `add` our vectors and `search`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.add(xb)  # this is also very slow for large nbits\n",
    "\n",
    "k = 100  # return top k results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 22.7 ms, sys: 13.2 ms, total: 35.9 ms\n",
      "Wall time: 9.37 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "dist, I = index.search(xq, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.49 ms ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "index.search(xq, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Search time is nothing special. PQ alone is still an exhaustive search, so we should not expect anything spectacular here, but we can make it faster, as we'll see later.\n",
    "\n",
    "Let's compare our results against those produced by a non-quantized flat index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "l2_index = faiss.IndexFlatL2(D)\n",
    "l2_index.add(xb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 46.1 ms, sys: 15.1 ms, total: 61.2 ms\n",
      "Wall time: 15 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "l2_dist, l2_I = l2_index.search(xq, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([1 for i in I[0] if i in l2_I])"
   ]
  },
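  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overlap count above can also be wrapped in a small helper. This is only a sketch: `recall_at_k` is a hypothetical name introduced here, not a Faiss function, and it assumes each argument is a flat sequence of result IDs for a single query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recall_at_k(approx_ids, exact_ids):\n",
    "    # hypothetical helper (not part of Faiss): fraction of the exact\n",
    "    # top-k IDs that also appear in the approximate top-k\n",
    "    exact = set(int(i) for i in exact_ids)\n",
    "    approx = set(int(i) for i in approx_ids)\n",
    "    return len(exact & approx) / len(exact)\n",
    "\n",
    "# PQ results versus the flat L2 ground truth for our single query\n",
    "recall_at_k(I[0], l2_I[0])"
   ]
  },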
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A recall of 50% is not cutting-edge, but it is a reasonable sacrifice if it allows us to search larger datasets. Let's see if PQ has made good on its promise of reduced memory usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def get_memory(index):\n",
    "    # write index to file\n",
    "    faiss.write_index(index, './temp.index')\n",
    "    # get file size\n",
    "    file_size = os.path.getsize('./temp.index')\n",
    "    # delete saved index\n",
    "    os.remove('./temp.index')\n",
    "    return file_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "256000045"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# first we check our Flat index, this should be larger\n",
    "get_memory(l2_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4131158"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# now our PQ index\n",
    "get_memory(index)"
   ]
  },
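  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check, both file sizes can be approximated by hand. The sketch below assumes the files hold little beyond the raw data: float32 vectors for the flat index, and fixed-width codes plus a float32 codebook for PQ. The variable names are illustrative, and the few bytes left unaccounted for are assumed to be header metadata."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_vectors, dim, n_sub, bits = 500_000, 128, 8, 8\n",
    "\n",
    "# flat index: every vector stored as dim float32 values (4 bytes each)\n",
    "flat_bytes = n_vectors * dim * 4  # 256_000_000\n",
    "\n",
    "# PQ codes: n_sub codes of `bits` bits each per vector\n",
    "pq_code_bytes = n_vectors * n_sub * bits // 8  # 4_000_000\n",
    "\n",
    "# PQ codebook: 2**bits centroids per subspace, dim float32 values in total per centroid row\n",
    "pq_codebook_bytes = (2 ** bits) * dim * 4  # 131_072\n",
    "\n",
    "print(flat_bytes, pq_code_bytes + pq_codebook_bytes)"
   ]
  },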
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've reduced our index size from **256MB** to just **4.1MB**, which is pretty impressive.\n",
    "\n",
    "But what about the underwhelming search speed? Is there anything we can do about that? Fortunately, yes. We can improve search speed with another quantization step: we add a *coarse* quantizer, `IndexIVF`, to the process.\n",
    "\n",
    "What we will now have is a *coarse quantizer* (IVF), which restricts the *scope of our search* to the most relevant samples, followed by our *fine quantizer* (PQ), which reduces the memory usage of our stored vectors by *restricting the scope of our vector space*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2**nbits=256\n"
     ]
    }
   ],
   "source": [
    "vecs = faiss.IndexFlatL2(D)\n",
    "\n",
    "nlist = 2048  # how many Voronoi cells (must be >= k* which is 2**nbits)\n",
    "nbits = 8  # when using IVF+PQ, higher nbits values are not supported\n",
    "index = faiss.IndexIVFPQ(vecs, D, nlist, m, nbits)\n",
    "print(f\"{2**nbits=}\")  # k*, the number of centroids per subquantizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.is_trained"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again we must `train` our index, and again this can take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.train(xb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we `add` our vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.add(xb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we're ready to test our new IVF -> PQ index!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist, I = index.search(xq, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "86.3 µs ± 15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "index.search(xq, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Lightning fast search...* But how is the recall?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "34"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([1 for i in I[0] if i in l2_I])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our recall limit will be roughly what we reached when using `IndexPQ` alone, so we have hit 34/50, or about 68% of the potential recall performance. Because we are now using an IVF index, we can push recall towards that limit by increasing the number of cells searched with `nprobe`. It is worth noting that, because index training involves some randomness, the maximum recall achievable will vary a little between runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.nprobe = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist, I = index.search(xq, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "39"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([1 for i in I[0] if i in l2_I])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "52"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.nprobe = 48\n",
    "dist, I = index.search(xq, k)\n",
    "sum([1 for i in I[0] if i in l2_I])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "52%, getting better... If we push the `nprobe` value up significantly, recall does not improve beyond 52%. This is our best-case performance in terms of balancing recall and search time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "52"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.nprobe = 2048  # this is searching all Voronoi cells\n",
    "dist, I = index.search(xq, k)\n",
    "sum([1 for i in I[0] if i in l2_I])"
   ]
  },
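  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the plateau more directly, we can sweep a few `nprobe` values in one loop. This is a minimal sketch that reuses `index`, `xq`, `k` and `l2_I` from the cells above. The particular values tried are arbitrary choices, and the exact counts will vary between training runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the last value searches every Voronoi cell, matching the cell above\n",
    "for nprobe in (1, 2, 8, 48, 512, 2048):\n",
    "    index.nprobe = nprobe\n",
    "    _, I = index.search(xq, k)\n",
    "    # number of returned IDs that also appear in the flat-index top k\n",
    "    overlap = len(set(I[0].tolist()) & set(l2_I[0].tolist()))\n",
    "    print(f\"nprobe={nprobe}: {overlap}\")"
   ]
  },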
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So it is better that we stick with `nprobe` of `48`. Let's take a look at the speed difference between this index and a flat L2 index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "index.nprobe = 48"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "134 µs ± 4.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "dist, I = index.search(xq, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.42 ms ± 344 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "l2_dist, l2_I = l2_index.search(xq, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A significant speed increase, from 8.42ms to 134µs. And what are the differences in memory usage?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "IVFPQ: 9196212\n",
      "Flat: 256000045\n"
     ]
    }
   ],
   "source": [
    "print(f\"IVFPQ: {get_memory(index)}\\nFlat: {get_memory(l2_index)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So our final stats are:\n",
    "\n",
    "| | FlatL2 | PQ | IVFPQ | Note |\n",
    "| --- | --- | --- | --- | --- |\n",
    "| Recall | 100% | 50% | 52% | *IVFPQ and PQ can attain equal recall with equivalent parameters; recall improves with greater `nbits`* |\n",
    "| Search time | 8.42ms | 1.49ms | 0.13ms | |\n",
    "| Index size | 256MB | 4.1MB | 9.2MB | |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 06:59:08) [MSC v.1916 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5fe10bf018ef3e697f9035d60bf60847932a12bface18908407fd371fe880db9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
